from torch import optim

from .base_optimizer import BaseOptimizerPlugin


class AdamWOptimizerPlugin(BaseOptimizerPlugin):
    """AdamW优化器插件"""
    
    def create_optimizer(self, model, config):
        """创建AdamW优化器实例"""
        optimizer = optim.AdamW(
            model.parameters(), 
            lr=config.get('learning_rate', 5e-4),
            weight_decay=config.get('weight_decay', 0.0)
            # 使用与MiniMind相同的默认参数
        )
        return optimizer
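

# A minimal usage sketch (assumptions: BaseOptimizerPlugin takes no
# constructor arguments, `model` is any torch.nn.Module, and the config keys
# are the 'learning_rate' / 'weight_decay' keys read above; this module uses
# a relative import, so it must be imported as part of its package):
#
#     from torch import nn
#
#     model = nn.Linear(16, 4)
#     plugin = AdamWOptimizerPlugin()
#     optimizer = plugin.create_optimizer(
#         model,
#         {"learning_rate": 5e-4, "weight_decay": 0.1},
#     )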